{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e6b92be5-6a8e-47b4-91fa-c8cf0b912227",
   "metadata": {},
   "source": [
    "## KNN分类 CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "de313295-8af5-4708-8d8a-4b7221fba1da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from skimage.feature import hog\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "\n",
    "def unpickle(file):\n",
    "    \"\"\"Load one pickled CIFAR-10 batch file.\n",
    "\n",
    "    ``encoding=\"bytes\"`` decodes Python 2 ``str`` objects as ``bytes``, which is why\n",
    "    the loaded batch is indexed with ``b\"data\"`` / ``b\"labels\"`` in later cells.\n",
    "\n",
    "    Args:\n",
    "        file: path (str or pathlib.Path) to e.g. ``data_batch_1`` or ``test_batch``.\n",
    "\n",
    "    Returns:\n",
    "        The unpickled object (a dict for the CIFAR-10 batch files).\n",
    "\n",
    "    Warning: ``pickle.load`` can execute arbitrary code -- only open trusted files.\n",
    "    \"\"\"\n",
    "    with open(file, \"rb\") as fo:\n",
    "        batch = pickle.load(fo, encoding=\"bytes\")\n",
    "    return batch\n",
    "\n",
    "\n",
    "def rgb2gray(im):\n",
    "    \"\"\"Convert RGB pixels to grayscale luminance.\n",
    "\n",
    "    Weights are 0.2989 R + 0.5870 G + 0.1140 B (ITU-R BT.601 luma). Channels are\n",
    "    taken from the *last* axis, so a single ``(H, W, 3)`` image and a stack\n",
    "    ``(N, H, W, 3)`` both work; channels beyond the third (e.g. alpha) are ignored.\n",
    "\n",
    "    Returns:\n",
    "        Array shaped like ``im`` without its last axis, e.g. ``(H, W)``.\n",
    "    \"\"\"\n",
    "    return im[..., 0] * 0.2989 + im[..., 1] * 0.5870 + im[..., 2] * 0.1140"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27eb0291-73d5-4f9f-aa07-562ac55f90b1",
   "metadata": {},
   "source": [
    "### 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "0bc9d921-22b6-4a1a-9e5c-8e85b8473f39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the pickled CIFAR-10 batches (data_batch_1..5 and test_batch).\n",
    "# Resolved from the current user's home directory rather than a hard-coded\n",
    "# /Users/<name>/ path, so the notebook is not tied to a single account.\n",
    "data_path = Path.home() / \"Downloads\" / \"cifar-10-batches-py\"\n",
    "train1_path, train2_path, train3_path, train4_path, train5_path = (\n",
    "    data_path / f\"data_batch_{i}\" for i in range(1, 6)\n",
    ")\n",
    "test_path = data_path / \"test_batch\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d941028b-1ca0-4431-b8ff-102a915c8b9b",
   "metadata": {},
   "source": [
    "#### 准备batch1数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "def460cd-d37f-4ecf-87f3-71026a1c296d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training batch 1: a dict keyed by bytes (b\"data\", b\"labels\", ...).\n",
    "t1 = unpickle(file=train1_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "de2275c2-af7e-4bab-bd58-5e4acdc83e8b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 3072)"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Later cells reshape every row to (3, 32, 32), so each row must hold 3 * 32 * 32 values.\n",
    "assert t1[b\"data\"].shape[1] == 3 * 32 * 32, \"unexpected CIFAR-10 row length\"\n",
    "t1[b\"data\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "64378dd4-0b99-419c-806a-c20e808c9b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat pixel rows and their class labels for batch 1.\n",
    "t1_data, t1_labels = t1[b\"data\"], t1[b\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "3da12f3c-fbf2-4c1b-a14e-85d7141c2979",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 10000)"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assert len(t1_data) == len(t1_labels), \"images and labels are misaligned\"\n",
    "len(t1_data), len(t1_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa15355d-471f-413c-a732-24fe3740c403",
   "metadata": {},
   "source": [
    "#### 查看原图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "13f3ae18-d84e-462e-8e3e-5192591eb34e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x13897eee0>"
      ]
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbeUlEQVR4nO2dbWxc93XmnzMzfBVFURQlWZacMnG9dbJ5UQzWTes0dRykcAMXdtrASBYbeIEgKhYJsAG6H4wssMku+iEtmgT5sEihxEbdIs1L89J4C28ax03iOGnt0I4syZZtyRb1ZooiJVF8GXJeTz/MqGG08xzSQ3KG0v/5AYKG88y999z/vXPmzn3mnL+5O4QQ6ZJpdwBCiPaiJCBE4igJCJE4SgJCJI6SgBCJoyQgROK0JQmY2Z1m9qKZHTOz+9sRw5JYxszskJkdMLPRFm/7QTM7Z2aHlzw3aGaPmtnR+v9b2xjLp83sTH1sDpjZ+1oQxw1m9kMze97MnjOz/1Z/vuXjEsTSjnHpNrOnzOzZeiz/q/78683syfp76etm1vmaV+7uLf0HIAvgZQBvANAJ4FkAb2p1HEviGQMw1KZtvwvALQAOL3nuLwDcX398P4A/b2Msnwbw31s8JrsA3FJ/vBnASwDe1I5xCWJpx7gYgL764w4ATwJ4B4BvAPhg/fm/AvBfX+u623ElcCuAY+7+irsXAXwNwN1tiKPtuPvjAC5c8fTdAB6qP34IwD1tjKXluPu4uz9TfzwL4AiA3WjDuASxtByvMVf/s6P+zwHcAeCb9eebGpd2JIHdAE4t+fs02jSwdRzA983saTPb18Y4LrPT3cfrj88C2NnOYAB83MwO1r8utOSryWXMbBjA21H71GvruFwRC9CGcTGzrJkdAHAOwKOoXVFPu3u5/pKm3ku6MQi8091vAfAHAD5mZu9qd0CX8do1Xjt/1/1FADcC2AtgHMBnW7VhM+sD8C0An3D3maVaq8elQSxtGRd3r7j7XgB7ULuivnkt1tuOJHAGwA1L/t5Tf64tuPuZ+v/nAHwHtcFtJxNmtgsA6v+fa1cg7j5RP/GqAL6EFo2NmXWg9qb7irt/u/50W8alUSztGpfLuPs0gB8C+G0AA2aWq0tNvZfakQR+DuCm+l3NTgAfBPBwG+KAmW0ys82XHwP4fQCH46XWnYcB3Fd/fB+A77YrkMtvujrvRwvGxswMwAMAjrj755ZILR8XFkubxmW7mQ3UH/cAeC9q9yh+COAD9Zc1Ny6tvMO55E7n+1C70/oygP/RjhjqcbwBNXfiWQDPtToWAF9F7XKyhNr3uY8A2AbgMQBHAfwAwGAbY/lbAIcAHETtTbirBXG8E7VL/YMADtT/va8d4xLE0o5xeSuAX9S3eRjA/1xyDj8F4BiAvwfQ9VrXbfUVCSESRTcGhUgcJQEhEkdJQIjEURIQInGUBIRInLYlgQ3yE10AioWhWBpzrcXSziuBDTOQUCwMxdKYayoWfR0QInFW9WMhM7sTwBdQ6xHwZXf/TPT6oaEhHx4eBgBMTk5i+/btTW97LVEsjbkylmq1Sl9bLpeplstlqeZVfv5lMr/8jJqamsLQ0NC//20Zo8vVSu/J9ppa6lfZyMeIMTY2hqmpqYa7mGv05EowsyyA/4Pab5hPA/i5mT3s7s+zZYaHhzE62rh5T3SCiTUmeCfUfi7fmIX5PNXOX5ii2uAgr7StFBep1tPbS7VsZxfV3PgFbjV4q/NUdfVz6628xmk1XwfUHESIa4DVJIGN1hxECNEE635j0Mz2mdmomY1OTk6u9+aEEK+R1SSBFTUHcff97j7i7iMb5WaKEOKXNH1jEEuag6D25v8ggP/U7MqW3gkWG5NC/hLVLpx+hWqnjvDlLs3MU+22O95Dtf6ebqpFn20W3BhM9QxsOgm4e9nMPg7gn1C7sfqguz+3ZpEJIVrCaq4E4O6PAHhkjWIRQrSBVK+AhBB1lASE
SBwlASESR0lAiMRZ1Y3BtURdj1tHNNYZ49rZU8epdvBfHqdaaYHXHHT08bqChRluLfYPDlItqg+I6gpSPQN1JSBE4igJCJE4SgJCJI6SgBCJoyQgROIoCQiROBvGIozaWom1xcFbuZUK3M579dQJqvX39lCtd2Az1c5dnKXa+fH/rzL939l5w+uohkzQ05AvtUzfwmsXXQkIkThKAkIkjpKAEImjJCBE4igJCJE4SgJCJM6GsQjF2tJspeDkhfNUGxs7SbVCsNzm7k6q5edmqPbCs7+g2nXDN1Jt4Lpg+otgXKJC1mvZwtaVgBCJoyQgROIoCQiROEoCQiSOkoAQiaMkIETiyCK8ZomssArVzpw+TbXjJ7l26hifi3Bocx/V9gxtotr4SV61eGj051QbuX2Aar39W6gW9Ce9pllVEjCzMQCzACoAyu4+shZBCSFax1pcCbzb3afWYD1CiDagewJCJM5qk4AD+L6ZPW1m+xq9wMz2mdmomY1OTk6ucnNCiLVmtUngne5+C4A/APAxM3vXlS9w9/3uPuLuI9u3b1/l5oQQa82qkoC7n6n/fw7AdwDcuhZBCSFaR9M3Bs1sE4CMu8/WH/8+gP/dfCi8+WXz3s06eD5BpZmHYrB/QYWaNZ2n+Tqr1TLVSuUS1Wbzi1Q7PXGBahOBVqnsoNqeHXzfX/j5U1Tbcd0uqv2H34w+p/jbIePBMYq6lwaHL1glLDpf1pjVuAM7AXynXmKZA/B37v69NYlKCNEymk4C7v4KgLetYSxCiDYgi1CIxFESECJxlASESBwlASESZwNVEUY+S7NrbNIijEIJG1UGGrgtF9qAoX0YaRFcfd3wMNV6N/dTbWZ+Idgc37/Dp85RrSfXRbXcYpFqz/3sx1Tbtnsn1bbueQPVrMyPrQVeX3QOVjNBM9i1fzvwbbVuU0KIjYiSgBCJoyQgROIoCQiROEoCQiSOkoAQibOBLMK1z0dhdVdAZPWhyrVq0MCzVOaWVmcnn6vPwp2IrKlosSyVtm4doto733U71Q4deIFqY8d5w9BKmY/ZsexZqnUPX8/X+eJRqh368U+p9lt/yPtd9PTyZqmVqBow0riEcpOWObONmyx0FEKkgJKAEImjJCBE4igJCJE4SgJCJI6SgBCJs3EswrDrYrPrjKr6gqqwYJVl59WAR49xa2phYZ5qN7/xjVTr6uJ2XibynwKqztdZDU6J37ntd6l28vgZqn35r75MtfICt05PTk5TrauXVxjeNMg/2178ySjVtgdVhDffxhuU5oMK0Y4qj6UzOH4X8peoVigWqMYs12KJL6MrASESR0lAiMRREhAicZQEhEgcJQEhEkdJQIjEWdYiNLMHAdwF4Jy7v7n+3CCArwMYBjAG4F53v7iaQKqBnRcV0oXNPStBc88o/QXWzakzJ6n2fx/5R6rNzHDL53emeLPNd//eHVTr6uI2WTSe0Sx35QpX+zZvptpdd99FtWMvvkS1H/y/R6k2U+LH74UzvMJwq/VQrXuRH/h//d73qZbbxqsIMzsHqDY/zY97R5VXUI7PnKbapVm+zsXFxvNFzuVn6DIruRL4awB3XvHc/QAec/ebADxW/1sIcRWybBJw98cBXDmt7N0AHqo/fgjAPWsblhCiVTR7T2Cnu4/XH59FbYZiIcRVyKpvDHrtSzn9Ampm+8xs1MxGJycnV7s5IcQa02wSmDCzXQBQ/5/e2XL3/e4+4u4j27fz9k1CiPbQbBJ4GMB99cf3Afju2oQjhGg1K7EIvwrgdgBDZnYawKcAfAbAN8zsIwBOALh39aFwuyTy8y5ePE+1SxevvJ+5ZJVZbgOeneSW3b+MPkW1p597lmozF6apVijxSrr/+JY3U23Hdt4UNJvlh3ZmNk+16elpqg3v2UO16/fsoNp/+eh/ptqpMy9T7clnD1KtMM8rIY+e5vZh73V8ufOHD1Mt/20q4cbbbqHaxblZvs7AtivYNNWiisAqaYQbNbpdNgm4+4eI
9J7llhVCbHz0i0EhEkdJQIjEURIQInGUBIRIHCUBIRKnxY1GHUBje6MaVFRFnT8vzUxR7Sc/e4JqJ17lVVpTM9NUuzjPLZ/MJj6nYHdhE9XOnY/24SdUGx6+gWpRheGZ0/yXm6Uit5IW8tNUm5vlWkdwlr3xN3lzzwPHDlGtOMurJE9Pc+utt5OPy54t3VQ7PvoM1bJd/LM0c/0g1S6VuVXLjUwAzs+zQqHx+8uD0lFdCQiROEoCQiSOkoAQiaMkIETiKAkIkThKAkIkTkstwoXFPJ470rjSLpfroMtFttXFoOpteo43ZDw5zufO27JjG9UGt/AmltuGeL+EyZfHqXbkMLfCHv0Bb8S5pZ/Hks1xk6lQ5PZasdC4USUAfO+fuNYRfJxEFYa9Q/y4v23vzVT7xRMvUi0ftFJ96fwE1Xoq3MbdWuZNVo/969NUm97ObccLGR5nR5EvVw4asObzjW3H2ZkFuoyuBIRIHCUBIRJHSUCIxFESECJxlASESBwlASESp6UW4fz8HH721M8aagsz83S5Td3curnrrrupVnZeMfb0oReotmXzVqotVLlNdv0OPgdLaYJbNJfmeTVZ/ii3wrYG1WubtvAx69vKrczuTdy22jLAbcct/f1U6+/n8/j19PVS7fY7fotql6a4/Xv48CtUq5R4SerJ6cAC7eBWZu4st+xmL3KtvJlbvJke3kT2zCluN8+Q91FxkdvsuhIQInGUBIRIHCUBIRJHSUCIxFESECJxlASESJyVzEX4IIC7AJxz9zfXn/s0gI8CuNyx8pPu/shy6yoUinhlrLF9c+ncRbrcTa+/iWo9PdwKe/VVPqfgieMnqda3iVs3hRK38yyo1FqY5lYRMty2+vUbeSPOG7dvodrmrdyyO3eO22tbB/nnwq4b+FjPzvBx6QyaXHZXue3YH+zfe+98N9UuXOSNRidO83NiqsAD7b3E17kjsEdzxis2d2/mTUg37byOamfGxqhWzDduhOtBI9+VXAn8NYA7Gzz/eXffW/+3bAIQQmxMlk0C7v44AD69rxDiqmY19wQ+bmYHzexBM+M/sRNCbGiaTQJfBHAjgL0AxgF8lr3QzPaZ2aiZjebz/DuzEKI9NJUE3H3C3SvuXgXwJQC3Bq/d7+4j7j7S28tvuAkh2kNTScDMdi358/0ADq9NOEKIVrMSi/CrAG4HMGRmpwF8CsDtZrYXtckFxwD8yUo2Vq1UMH+psT2VX+RfFbp6edPFS7Pc7jpxaoxqA1u4rVOZ59Vktth4rjcAGD97jGuv8vkGLcPXee8f/xHVqnP8fu0/P/Ejqp04yJusbtvC57k7e5Rbmbuvfx3VLpV4c090cMtucBuvynzLb7yZasV7+Gn94AN/S7WFWX7cX52eoxpywdyARW47zk2dp9r1wfnZ2cMrGod2DDR8fuocPwbLJgF3/1CDpx9YbjkhxNWBfjEoROIoCQiROEoCQiSOkoAQiaMkIETitLTRaNWrKBYaW4H5Am80euw4t96+8w/fotoTP/4x1cy53TUxw+2gyROnqNYRVMuVgiquzut4tdxPH/8J1Qoz3HZ8/uhLVJuf4BWN05M8zoFt3KqdDJptzlzix3brAP8BWbHC9+FHP3qGaj39fC7JrUN8XsSpErfs8gW+f2cCa9G7+HnWG4xLdpJbpwPb+PmSzTZ+S798lDdf1ZWAEImjJCBE4igJCJE4SgJCJI6SgBCJoyQgROK01CLM5rLYMtjY3igF6Whmjjd5fP7AAapNHD9OtUyw6705XqXVmeEVY16M5nvjVtGeXbupNhjMi3gxaNLyhuHfoNqJCm/qOn2B22SVrgGqTQSVl/k8tx2nL/DqNsvyJqSLFuxD/mWqZTq5JVnNBse2k8eSB/eGK2WubQpi6dvCj3s2y98sVW881tlgLHUlIETiKAkIkThKAkIkjpKAEImjJCBE4igJCJE4rbUIs1n0EYswt5nPc1c8z6utpl7iVX039PFqKwusvtkFbnct
Zng1mfXwKrsu4xbN5ARvGPr0k89SbefmzVQ7f3GaapcWuLU4F1RCLkxxqxaBBZoLrLeeDj5X32JguU5OT1OtkuFj3Zvjtpxl+GdippuvE4FFCC9RaX6eH4eZYF7LrdsGglDYceDHR1cCQiSOkoAQiaMkIETiKAkIkThKAkIkjpKAEImzkrkIbwDwNwB2ojb34H53/4KZDQL4OoBh1OYjvNfdeWkXADeg2tk473iFWxidQdVUR4lXqL2uf5Bq5cBGmg0stGx/H9UyndwiXJjgcyYWpvM8lvOzVJuq8nGZLvB1Dt/yVqqdneRVhNMX+T709XGLdzHPLd5SBx+zxaC550KJ23KZDD+XuoNj5MbtvEpgA2Zz/G2UKXMLtFrl6zw3OU21Mj/lketsvO/lSjBefHW/XB7An7r7mwC8A8DHzOxNAO4H8Ji73wTgsfrfQoirjGWTgLuPu/sz9cezAI4A2A3gbgAP1V/2EIB71ilGIcQ68pruCZjZMIC3A3gSwE53H69LZ1H7uiCEuMpYcRIwsz4A3wLwCXf/ld+Purujdr+g0XL7zGzUzEbzc/y7thCiPawoCZhZB2oJ4Cvu/u360xNmtquu7wLQcMoUd9/v7iPuPtLbx3+3LYRoD8smATMzAA8AOOLun1siPQzgvvrj+wB8d+3DE0KsNyupIrwNwIcBHDKzA/XnPgngMwC+YWYfAXACwL3LrahSqWJ6urHlVcjzirFNRW7nbb/ueqqdP8Hnczs2doJqkyVeRTg4yG3HTDe/0pmvcve0UuKWVjlfoNpigXtFZePW1ORZPofh/By3Fr3E19nb1Uu1YlCVaV1dVCsv8n3v3MQtSQ/ssMUCP8+qGb5/xTJfrquDV0l2dvP96+vldnNPoJWC45BhlZB8keWTgLs/AV6H+J7llhdCbGz0i0EhEkdJQIjEURIQInGUBIRIHCUBIRKnpY1GUTVggczzx90glI1bMPNB/8fxoLnneDBH3FwxaBx5nlfSZTu4vZYPKsacNocEFsq8ks7JvHMA0BnYVmcmuUUYVZtZ0Kxy8mJQQGp8Oa/wfejo4ZZrfyffv0pQZlf7cWtjsjn+mdgDPj9lJqpyDY6DBfvgwfliwfYyRt7SwTHQlYAQiaMkIETiKAkIkThKAkIkjpKAEImjJCBE4rTUIjQz5Kyx1VIKrJu5Be4fXpjh8+NdKPLlyh18173MrcXFqCIuqFAredQYk29v05Z+qmWzfLmo+aUHqT+00KLtBVrU+DOY/g/VaG7AcN/5WFeqgX0YxRnuH4/TAmsOxperBnEGrjHKTAyOq64EhEgcJQEhEkdJQIjEURIQInGUBIRIHCUBIRKnpRZhtVLB3OxcQ21mhs9XNx/MVzA/H1h2gTvTP8Ctt64e3hwywgKrqCfHK8Y6Ovn2IuutI7A5I4uwElU0BlZS1K0yWiwb+YBBQ9RKUGFIrTDE+1AKlqsE+5fN8eOQi+zYIJbubj4vYldkYQf2YRdp3BpZlboSECJxlASESBwlASESR0lAiMRREhAicZQEhEicZS1CM7sBwN8A2ImaR7Tf3b9gZp8G8FEAk/WXftLdH4nWVS6XMXX+fEOtVOS2x+Iir84rFrnW0c2bQ3Z0c8tuYYFbklFTyagaEIHmHsxFWOGWViZqjNnLbcfIyoy8vshajIjsqah5aUQ+z5u6RtZiLrLegirCaMyi/Yst12Dfg8W6gzkvmUUYVTqu5HcCZQB/6u7PmNlmAE+b2aN17fPu/pcrWIcQYoOykglJxwGM1x/PmtkRALvXOzAhRGt4TfcEzGwYwNsBPFl/6uNmdtDMHjSzrWsdnBBi/VlxEjCzPgDfAvAJd58B8EUANwLYi9qVwmfJcvvMbNTMRguFYIYRIURbWFESMLMO1BLAV9z92wDg7hPuXnH3KoAvAbi10bLuvt/dR9x9hN20EEK0j2WTgNVufT4A4Ii7f27J87uWvOz9AA6vfXhCiPVmJe7AbQA+DOCQmR2oP/dJAB8y
s72omRljAP5kuRVV3VEqEUsv6H6Zy3GrL7q46ArmsovcGTadGxBX9VUDW6cS2ICRpZUNrMVsZ9D8soOPZ2cwnpGlFcUZW2GcoCAutLUGBgaoViqVqFYILOVKUNHYrA0YVTuWyzxOVAIt8A/ZMaoEc0yuxB14Ao3fMuFvAoQQVwf6xaAQiaMkIETiKAkIkThKAkIkjpKAEInT0kajuVwO27Zta6hlwG2rSiVqHBnMOxdYPouLvFLQskE1WTh/HI+lGFg02WpQfRgQ25Xce4vGrNmqvqipazXwTstlHmc1OO5R48/IlosajZaqQcVmMNbN2ofh/I1N2IAAPwc9mguTKkKIJFASECJxlASESBwlASESR0lAiMRREhAicVpqEWazWfT3N54DsFqJmi7yXFUo8mqrmXzjeQ8BINcRVOcFWmTPIJA6goq4cmAtViM7KLABEViZFlQ0hqWQAdXACqsG9qgHn0PVwNYqLvBqwKiKsBp18AwajUajElnDHizZG8xF2BlYoJnAkmTzIkYVmboSECJxlASESBwlASESR0lAiMRREhAicZQEhEicllqEAGAk71hQ8Vcs8fkKFgu8GpA2NUVcFZYL7BQP7K5iUKFWCKrlrMk58CKrKLKEqmU+1k3OnIdolkIP4ozmN3QLKt9yfJ0dWV6RGhE5p3ED1sAejQY0quwLLN5ouXKp8XmmKkIhBEVJQIjEURIQInGUBIRIHCUBIRJHSUCIxFnWIjSzbgCPA+iqv/6b7v4pM3s9gK8B2AbgaQAfdnfuyQGA84qrQiGqCuNasbjItWCdxRK386LqtagRZ9Q4sjuYNDETVIxVAtsxsq2iyjYL5jeM9i+yHTuDfY9YXOTHL2oYmg1iiY5DNGaFArei8/mgMW1ggXYHlYLRPpSLPJbIPuzubnyeRTGu5EqgAOAOd38bgL0A7jSzdwD4cwCfd/dfB3ARwEdWsC4hxAZj2STgNS4X5nfU/zmAOwB8s/78QwDuWY8AhRDry4ruCZhZtj4t+TkAjwJ4GcC0u1++XjsNYPe6RCiEWFdWlATcveLuewHsAXArgJtXugEz22dmo2Y2urDAv1cJIdrDa3IH3H0awA8B/DaAATO7fGNxD4AzZJn97j7i7iM9PT2riVUIsQ4smwTMbLuZDdQf9wB4L4AjqCWDD9Rfdh+A765TjEKIdWQlVYS7ADxkZlnUksY33P0fzex5AF8zsz8D8AsADyy3InenTSAjGzCyihBYPqzpIgAgtMk44fxxUfVhUCkYzY8X7UPU9NSCesBsUGWXicalyTn3PLArOzs7g1j4eDZrLXZ08H1v9thGxyGKpZPYeQDQ29VLtej8ZMcosn6XTQLufhDA2xs8/wpq9weEEFcx+sWgEImjJCBE4igJCJE4SgJCJI6SgBCJY5G1s+YbM5sEcKL+5xCAqZZtPEaxNEaxNOZqjOXX3H17I6GlSeBXNmw26u4jbdn4FSiWxiiWxlxrsejrgBCJoyQgROK0Mwnsb+O2r0SxNEaxNOaaiqVt9wSEEBsDfR0QInGUBIRIHCUBIRJHSUCIxFESECJx/g3AEsheB+yLfQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sample_idx = 2\n",
    "# Rows are stored channel-first (3 planes of 32x32); move channels last for display.\n",
    "mat = t1_data[sample_idx].reshape(3, 32, 32).transpose(1, 2, 0)\n",
    "plt.imshow(mat)\n",
    "plt.title(f\"batch 1, sample {sample_idx}, label {t1_labels[sample_idx]}\")\n",
    "plt.axis(\"off\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed973c68-b5d7-4bc3-b4fb-af6aad8702fc",
   "metadata": {},
   "source": [
    "#### 尝试HOG特征"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "73e3e774-3417-4ca2-960a-1065b3a779f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x133c28dc0>"
      ]
     },
     "execution_count": 142,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVrUlEQVR4nO3df4xddZnH8fdzb6cz/TW0paUObaWUVhBRC5kF/BEXMRrWNQE3LpFNTP8w1jWykcTdyLLJiptNVjb+2Jps3NSlsRpWQdCFGLKKDUlj1q0MWEuhQgu20lpaaDttaTtl5t5n/7in7iyZ5zvTM/fHTL+fVzLpnfvce85zz8w8Pfc89/v9mrsjIvmqdDoBEeksFQGRzKkIiGRORUAkcyoCIplTERDJXEeKgJndZGbPmdluM7uzEzmMymWPmT1tZtvMbKDN+95oZofMbMeo+xaa2WNmtqv4d0EHc7nbzPYXx2abmX24DXksN7PHzexZM3vGzD5X3N/245LIpRPHpcfMfmlmvy5y+VJx/6VmtrX4W7rfzGae88bdva1fQBV4AVgJzAR+DVzZ7jxG5bMHWNShfb8PuAbYMeq+fwbuLG7fCdzTwVzuBv66zcekD7imuD0PeB64shPHJZFLJ46LAXOL213AVuB64AHg48X9/wZ85ly33YkzgWuB3e7+oru/DnwfuLkDeXScu28Bjrzh7puBTcXtTcAtHcyl7dz9gLs/Vdw+AewEltKB45LIpe284bXi267iy4EbgQeL+0sdl04UgaXAS6O+30eHDmzBgZ+a2ZNmtq6DeZy1xN0PFLdfBpZ0MhngdjPbXrxdaMtbk7PMbAVwNY3/9Tp6XN6QC3TguJhZ1cy2AYeAx2icUQ+6+0jxkFJ/S7owCO9192uAPwE+a2bv63RCZ3njHK+Tn+v+JnAZsAY4AHy1XTs2s7nAQ8Ad7n58dKzdx2WMXDpyXNy95u5rgGU0zqivaMZ2O1EE9gPLR32/rLivI9x9f/HvIeBHNA5uJx00sz6A4t9DnUrE3Q8Wv3h14Fu06diYWReNP7r73P2Hxd0dOS5j5dKp43KWuw8CjwPvAuab2YwiVOpvqRNF4AlgdXFVcybwceCRDuSBmc0xs3lnbwMfAnakn9VyjwBri9trgYc7lcjZP7rCR2nDsTEzA+4Fdrr710aF2n5colw6dFwWm9n84vYs4IM0rlE8DnyseFi549LOK5yjrnR+mMaV1heAv+tEDkUeK2l0J34NPNPuXIDv0TidHKbxfu6TwIXAZmAX8DNgYQdz+S7wNLCdxh9hXxvyeC+NU/3twLbi68OdOC6JXDpxXN4B/KrY5w7g70f9Dv8S2A38AOg+121bsSERyZQuDIpkTkVAJHMqAiKZUxEQyZyKgEjmOlYEpshHdAHlElEuYzvfcunkmcCUOZAol4hyGdt5lYveDohkblIfFjKzm4D1NOYI+Hd3/3Lq8TOt23uYA8AwZ+iiu/S+m0m5jE25jG065jLESV73MzZWrHQRMLMqjY/+fpDGx0yfAG5z92ej5/TaQr/OPlBqfyJS3lbfzHE/MmYRmMzbAU0OInIemEwRmGqTg4hICTPGf8jkFC2MdQA9zG717kTkHE3mTGBCk4O4+wZ373f3/qlyMUVE/s9kzgT+MDkIjT/+jwN/0ZSs5A+qixeHMZszK4zVX44n3qkPDU0qp3O1/wvvDmOXbHohjNVPngpjdnE8xWDtud0TS6xZbMzrbQ3TYKh+6SLg7iNmdjvwExotwo3u/kzTMhORtpjUNQF3fxR4tEm5iEgH6BODIplTERDJnIqASOZUBEQy1/IPC8kkVeL2k8+KP3fR7jZg9fJVYWzpPf8dxnb907vC2MqHToSx+kA81X+1tzeM1Y4fD2OlpdqAlWocq9ean0sJOhMQyZyKgEjmVAREMqciIJI5FQGRzKkIiGTu/G4RTvPRXdNJauReqn1YWxa3MkfmzgxjM9vdBizJ
Ui1enxq/nzoTEMmcioBI5lQERDKnIiCSORUBkcypCIhkbnq0CFOtvlZssxXtmbKvYdGCeJOHB5u/v5TEcUmN3LNTcRtw2f0Lw9iBvzoZxpZ/+ZI4l+f2hrGUZGsxNRqwJJsZt0B9eKS5O0sMWNSZgEjmVAREMqciIJI5FQGRzKkIiGRORUAkc9OjRXg+jPgr+xrq9cQmE9ucQsds5KV9YWzWkvlhbPj5uO3oT/wi3mFq1GLZdQpLTgrq8Y8P6omfURsnIZ1UETCzPcAJGl3IEXfvb0ZSItI+zTgTeL+7v9qE7YhIB+iagEjmJlsEHPipmT1pZuvGeoCZrTOzATMbGObMJHcnIs022bcD73X3/WZ2EfCYmf3G3beMfoC7bwA2APTawqlztUpEgEmeCbj7/uLfQ8CPgGubkZSItE/pMwEzmwNU3P1EcftDwD80LbOMVHp6wpidjt9CearF1GapEXipiUY9MeJv9frZYWzvF94dxlJrH04pqTZgG0e5TubtwBLgR9ZIdgbwH+7+X03JSkTapnQRcPcXgXc2MRcR6QC1CEUypyIgkjkVAZHMqQiIZM6SI9GarNcW+nX2gbbtT0QatvpmjvuRMfuOOhMQyZyKgEjmVAREMqciIJI5FQGRzKkIiGRORUAkcyoCIplTERDJnIqASOZUBEQypyIgkjkVAZHMTYu1CCvz5sWxhfPDmB+LJ7+kuzsM1Q4emkhaTVO98i1hrL5rTxirzJ0TxvxMPEFp/dSpCeXVLLvWXx/GVt+xNYzZjK5S+/Ph10s9r6zKVVfEwRd/F4ba/XOI6ExAJHMqAiKZUxEQyZyKgEjmVAREMqciIJK5cVuEZrYR+AhwyN2vKu5bCNwPrAD2ALe6+9HJJFKZHa87Vz9xIs5veV+80e6ZYaj2/Ath7LVb45bW3Af+J95fwumb47VaZz38yzBWWXNlGLNXBsNY7Wj846i9/5owVn38qTCWsj+xNuDqz8VrA+76xnVh7PK/3RHG6idPhrHqhQvDWO3wkTCW8tqfx3nO/UGizflHb483+sTTpXKxq98WxvxXz5zz9iZyJvBt4KY33HcnsNndVwObi+9FZBoatwi4+xbgjeXzZmBTcXsTcEtz0xKRdil7TWCJux8obr9MY4ViEZmGJn1h0Burl4QrmJjZOjMbMLOBYeKPsopIZ5QtAgfNrA+g+Df8sL27b3D3fnfv7yL+vL6IdEbZIvAIsLa4vRZ4uDnpiEi7jbsWoZl9D7gBWAQcBL4I/CfwAPBmYC+NFuG4vZeyaxGm2of+1kvD2O//+IIwNm9fPYyVbQOWlWofDq6Ku7hzE69h9qF4JF3ZNmBZqfbhxr9cH8a+dMOfhTF/LW4Rlm0DlpVqHx69ohrG3vzjY6X2V6YNmFqLcNzPCbj7bUFIK4uKnAf0iUGRzKkIiGRORUAkcyoCIplTERDJ3LgtwmZKtQirb10dPu+lP10cxk5dHLfJ3vSL+LXN2/1aGKsMlZuosrZzVxiz/qvC2MHresPY0IXx/oYuHgljc34bN36WbU5MwJrgA/GovhnLl4Wx+sJ4otja7Hik508f2hTGVt33mTB2+TdeCmMpIy/tC2NDH4nbuEcvLzdfb8+Nr4Sxw7viH/yyn8W/85FtW9ZzYnDfmC1CnQmIZE5FQCRzKgIimVMREMmcioBI5lQERDI3ZdYiPPb2uCVy8VfiiSqrb7ks3uiMeARX7dnnw5inJj0tuX7cyLy4FXbRv8avb8bSi+NcFs+PY9ueDWOnSk56muKJ41JPtN6qc+L1FFc++OkwtvpvfhHG9iZGLS69Jz7WKbP2x6MWe3587qP6AGxLPAnpgifikaxlJj01j38+OhMQyZyKgEjmVAREMqciIJI5FQGRzKkIiGRuyowiTKkuuSgOnonXMrAL4tF59SODcSyx9mErpCZSte54mvZ6YrLNyuoVYSzVHm0F64rboyk+MhzGdv1L3CZb/bn2ThRbVurnzso3h6H6jt+c875S
E43qTEAkcyoCIplTERDJnIqASOZUBEQypyIgkrmJrEW4EfgIcMjdryruuxv4FHB2psS73P3R8XZWtkUoIpMz2Rbht4Gbxrj/6+6+pvgatwCIyNQ0bhFw9y1Ae5d5FZG2mcw1gdvNbLuZbTSzBU3LSETaqmwR+CZwGbAGOAB8NXqgma0zswEzGxgm/oiviHRGqSLg7gfdvebudeBbQDhflbtvcPd+d+/vIv4cvIh0RqkiYGZ9o779KBCvTyUiU9q4E42a2feAG4BFZrYP+CJwg5mtARzYA8QzQorIlDZuEXD328a4+94W5CIiHaBPDIpkTkVAJHMqAiKZUxEQyZyKgEjmpsxahCnVBYlPJc+IX0JqfbzKksVhbOTFPRNJqy1m9L0pjPmcWXHswKEwVj8ZT1DaEpV4Tcjq4ngNyvrhckNWfGSk1PPKqsybF8asO55ktX6s3IS2Pvx6qedFdCYgkjkVAZHMqQiIZE5FQCRzKgIimVMREMnclGkRVhfHLbvaK6+EserqlfFG++L208j2eD03678qjPlAuVHT1ctXhbHac7vjJ1biOu2zE+sUJtqApXNJqLzjijiX1LFOtHhTsfrQUBirvuWyMFZ7/oUwllK9cGG8zUQrc8aiS8KYnYlbfan1MFMtyTLraOpMQCRzKgIimVMREMmcioBI5lQERDKnIiCSuSnTIky2ARPtw5FFcbvk+Mp4lN38mc1vA6akWm+plt2Zvt44tqArjM070/w2YEqqDZhqH56+aE4Y6z4Uv/bq0HAYK9sGTEm1AZPtw/lz4+cNlVuHo0wbMEVnAiKZUxEQyZyKgEjmVAREMqciIJI5FQGRzE1kLcLlwHeAJTTWHtzg7uvNbCFwP7CCxnqEt7r70dKJLL04jJ1859IwNrgqbpP17o0nnLTX41g1kUvKyP7fh7EZK1eEsVMr5oexwVXxRJXdx+ph7HRim7OH41xSUhOwVt+6OoydWha3cWs98f9Dxy+NJ5iddTh+7fOq5f5vq+3cFcaqSy4KYzZndhyr1cLY6bfFv9c9Lx2LcxmMW99hHq/Gf+oTOVojwOfd/UrgeuCzZnYlcCew2d1XA5uL70Vkmhm3CLj7AXd/qrh9AtgJLAVuBjYVD9sE3NKiHEWkhc7pvMnMVgBXA1uBJe5+oAi9TOPtgohMMxMuAmY2F3gIuMPdj4+OubvTuF4w1vPWmdmAmQ0MU+5jkiLSOhMqAmbWRaMA3OfuPyzuPmhmfUW8DxhzyRt33+Du/e7e30U8HZaIdMa4RcDMDLgX2OnuXxsVegRYW9xeCzzc/PREpNUmMorwPcAngKfNbFtx313Al4EHzOyTwF7g1skkUuuLR2J1P/pEGOtLjFCz0/FEjrVdL8bPKznpaYp3xYd65k8GwtibDsSvr3L4eBhLtSu9BRONUotbdqmfX6WnJ4z1JkbnJV9fyUlPU/zU6TBWOxiv+5iaFHTWicTr++3eMFZm0lP3uCU+bhFw958DFoQ/MN7zRWRq0ycGRTKnIiCSORUBkcypCIhkTkVAJHPW+LBfe/TaQr/Ozr2hkBqBVz8Yt+xsdjy6i5G4ZVI7WnowZCmVOfFkm9aXGL12Mm5bjRx4eVI5NVNqTcGUSqoV9srh+In1eOReK1hXPNIzpXJB3D70kusURrb6Zo77kTG7fDoTEMmcioBI5lQERDKnIiCSORUBkcypCIhkbsqsRZiSmuAy6eTJpubRKvVUnrt/275EWsQT7diU1Oi8qcSH43ZeSu3VRJuzjXQmIJI5FQGRzKkIiGRORUAkcyoCIplTERDJnIqASOZUBEQypyIgkjkVAZHMqQiIZE5FQCRzKgIimRt3FKGZLQe+Ayyhsfz4Bndfb2Z3A58Czs70eZe7P9qSLC1aBS2t0h2vgpwa2VZ21Fu7pSa49Fpiss02T8SZ/PmlJrqtVBPPi9c+TG6zFVJ5pp42syuM1V8fjp/Y5J/fRIYSjwCfd/enzGwe8KSZPVbEvu7uX2lqRiLS
VhNZkPQAcKC4fcLMdgJLW52YiLTHOV0TMLMVwNXA1uKu281su5ltNLMFzU5ORFpvwkXAzOYCDwF3uPtx4JvAZcAaGmcKXw2et87MBsxsYJgzk89YRJpqQkXAzLpoFID73P2HAO5+0N1r7l4HvgVcO9Zz3X2Du/e7e38X8YU6EemMcYuAmRlwL7DT3b826v6+UQ/7KLCj+emJSKtNpDvwHuATwNNmtq247y7gNjNbQ6NtuAf49GQSqS6J19xLTjiZas90xS0YHxqKc5l/QZzL4LF4fwmWaleeSbxNSrTXrBrX8NTkl9UF8eWbsuswVlPrBh4+Ej8x9foqccxH4jZgsnVaclLQZJsz0bJLrlOY+P0k8ftZmRevYVhmncKJdAd+Dox1BFrzmQARaSt9YlAkcyoCIplTERDJnIqASOZUBEQyN2XWIky1AVPtQ6sk6lg1bh+m2mtl24ApqTZgqn1oidZUZcH8+HlD8f7KtgFTUm3AVPsw1Xrz03GbDCvXHi0tNTIx1T5MtDkrc+eUSqVMGzBFZwIimVMREMmcioBI5lQERDKnIiCSORUBkcxNmRZhqg1I79wwVJ/TEz+vHk9GaRfE26wmWnYpyTZnYmSiJWI+K5HL0eNxLDGJZSqXlFTrtNrbGz8xMdLTEnlabzxaLtWy88PlWqC14/HxTI0GrFwQ52mJNqAfGYyTGY4nGk2OTAy3l2hVnvvWROR8oiIgkjkVAZHMqQiIZE5FQCRzKgIimZsyLUISo+xqu+LWW3LSxZOn4v0lJocsPelpSmLU28ie34WxSk/cAq2lJihNtNBKT3qaUE+M+PNU621G/CuYXE8x8fpKT3qaYD3xMau9ejiMVc7EIxo9MdIzORKyzNqOieOlMwGRzKkIiGRORUAkcyoCIplTERDJnIqASObGbRGaWQ+wBeguHv+gu3/RzC4Fvg9cCDwJfMLdS8/wmBqhlmojpUZb4fEowlSbpXQbMCE5uWdilJ3XEq8hNfllQtk2YHKbJVtaXk+8hpKTe5ZtA6YkJ/dMrYeZ+P1MtkBT20y0t8uYyJnAGeBGd38nsAa4ycyuB+4Bvu7uq4CjwCebmpmItMW4RcAbXiu+7Sq+HLgReLC4fxNwSysSFJHWmtA1ATOrFsuSHwIeA14ABt19pHjIPmBpSzIUkZaaUBFw95q7rwGWAdcCV0x0B2a2zswGzGxgmOa/FxWRyTmn7oC7DwKPA+8C5pvZ2St2y4D9wXM2uHu/u/d3UW7aLhFpnXGLgJktNrP5xe1ZwAeBnTSKwceKh60FHm5RjiLSQhMZRdgHbDKzKo2i8YC7/9jMngW+b2b/CPwKuLdVSfrISKnYtJFo+XiT20EdkWr1ecnXV7I92hKJn1F9aOr//MYtAu6+Hbh6jPtfpHF9QESmMX1iUCRzKgIimVMREMmcioBI5lQERDJn3sZWi5m9Auwtvl0EvNq2nacpl7Epl7FNx1wucffFYwXaWgT+347NBty9vyM7fwPlMjblMrbzLRe9HRDJnIqASOY6WQQ2dHDfb6RcxqZcxnZe5dKxawIiMjXo7YBI5lQERDKnIiCSORUBkcypCIhk7n8Bbp6Sz9wtfD4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# HOG of the colour sample above: 8x8-pixel cells and 1x1-cell blocks on a 32x32 image\n",
    "# give 4 * 4 cells * 9 orientations = 144 features.\n",
    "normalised_blocks, hog_image = hog(\n",
    "    mat,\n",
    "    orientations=9,\n",
    "    pixels_per_cell=(8, 8),\n",
    "    cells_per_block=(1, 1),\n",
    "    visualize=True,\n",
    "    transform_sqrt=False,\n",
    "    channel_axis=-1,  # mat is (H, W, 3); scikit-image >= 0.19 needs this for colour input\n",
    ")\n",
    "\n",
    "plt.matshow(hog_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "d4c99ddd-def7-447e-9b19-ca80e17a4d97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((144,), (32, 32))"
      ]
     },
     "execution_count": 143,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (HOG feature-vector length, HOG visualisation size)\n",
    "(normalised_blocks.shape, hog_image.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "bde9b39d-b21f-487d-b33d-732def0501f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.295545853927269"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average HOG feature value of the sample image.\n",
    "np.mean(normalised_blocks)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d968269-9513-4874-a88e-8c19c672e96c",
   "metadata": {},
   "source": [
    "#### 准备剩余batch数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "77a16002-aa04-4480-ad07-634c31379fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load training batches 2-5, then split each into pixel rows and labels.\n",
    "t2, t3, t4, t5 = (\n",
    "    unpickle(path) for path in (train2_path, train3_path, train4_path, train5_path)\n",
    ")\n",
    "\n",
    "t2_data, t2_labels = t2[b\"data\"], t2[b\"labels\"]\n",
    "t3_data, t3_labels = t3[b\"data\"], t3[b\"labels\"]\n",
    "t4_data, t4_labels = t4[b\"data\"], t4[b\"labels\"]\n",
    "t5_data, t5_labels = t5[b\"data\"], t5[b\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "8deb84a4-df12-4e79-a9ab-4524861af228",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000,)"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# All training labels, concatenated in batch order 1..5 (the order used for X_train_hog).\n",
    "y_train = np.concatenate([t1_labels, t2_labels, t3_labels, t4_labels, t5_labels])\n",
    "y_train.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a52104de-4e3a-41f3-8c20-36a17815feb5",
   "metadata": {},
   "source": [
    "#### 准备测试数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "e479683d-aa9b-424a-8b52-be2d2c287f31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out test batch, split the same way as the training batches.\n",
    "tst = unpickle(test_path)\n",
    "tst_data, tst_labels = tst[b\"data\"], tst[b\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "997071c5-9696-4eea-9590-cb3fe222762a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test labels as a NumPy array, like y_train.\n",
    "y_test = np.asarray(tst_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29715925-404d-4a5e-b0e1-25818ad1f4bf",
   "metadata": {},
   "source": [
    "### 使用HOG特征"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f1c7ab4-665d-4fa8-ab02-504571bdaaf9",
   "metadata": {},
   "source": [
    "#### 准备HOG训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b56a1317-73ed-4e95-b803-29f2e0ea1249",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HOG settings shared by the train and test feature extraction below.\n",
    "# 32x32 image / 8x8-pixel cells -> 4x4 cells; 2x2-cell blocks stepping one cell -> 3x3 blocks;\n",
    "# 3 * 3 blocks * 2 * 2 cells * 9 orientations = 324 features per image.\n",
    "orientations, pixels_per_cell, cells_per_block, feature_vector = 9, (8, 8), (2, 2), True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "94e15a15-228b-4f90-8bfc-12e4e53f0d71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 324)"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# HOG descriptor of every training image, stacked in batch order 1..5 so rows line up\n",
    "# with y_train. The row length is taken from hog() itself (324 for the settings above)\n",
    "# instead of a hard-coded 324, so changing the settings needs no edit here.\n",
    "train_batches = [t1_data, t2_data, t3_data, t4_data, t5_data]\n",
    "X_train_hog = np.stack(\n",
    "    [\n",
    "        hog(\n",
    "            item.reshape(3, 32, 32).transpose(1, 2, 0),  # channel-first row -> (H, W, C)\n",
    "            orientations=orientations,\n",
    "            pixels_per_cell=pixels_per_cell,\n",
    "            cells_per_block=cells_per_block,\n",
    "            feature_vector=feature_vector,\n",
    "            channel_axis=-1,  # colour input; scikit-image >= 0.19 no longer infers it\n",
    "        )\n",
    "        for batch in train_batches\n",
    "        for item in batch\n",
    "    ]\n",
    ")\n",
    "\n",
    "X_train_hog.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ef5870e-6305-48e5-8029-e923e167b466",
   "metadata": {},
   "source": [
    "#### 训练集交叉验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "4e13e51c-105b-4a82-9198-d4a09b40667a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed neighbourhood size (TODO: consider choosing k by cross-validation).\n",
    "k = 11\n",
    "knn_hog = KNeighborsClassifier(n_neighbors=k, weights=\"uniform\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "1b164a22-b0d5-457e-9331-e9a56ff826ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.5257, 0.5186, 0.5233, 0.5212, 0.5228])"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 5-fold cross-validated accuracy of HOG + KNN on the training set.\n",
    "scores = cross_val_score(\n",
    "    estimator=knn_hog, X=X_train_hog, y=y_train, scoring=\"accuracy\", cv=5\n",
    ")\n",
    "scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5daf79df-8346-42c1-87e2-4b635cad1c98",
   "metadata": {},
   "source": [
    "#### HOG特征测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "657e4b43-156b-4684-9e1e-06a893940b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HOG descriptors of the test images, extracted exactly like the training features;\n",
    "# the row count and feature length come from the data instead of hard-coded numbers.\n",
    "X_test_hog = np.stack(\n",
    "    [\n",
    "        hog(\n",
    "            item.reshape(3, 32, 32).transpose(1, 2, 0),  # channel-first row -> (H, W, C)\n",
    "            orientations=orientations,\n",
    "            pixels_per_cell=pixels_per_cell,\n",
    "            cells_per_block=cells_per_block,\n",
    "            feature_vector=feature_vector,\n",
    "            channel_axis=-1,  # colour input; scikit-image >= 0.19 no longer infers it\n",
    "        )\n",
    "        for item in tst_data\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "39418647-d175-4e75-8fdb-9f9a32f80942",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5255"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fit() returns the estimator itself, so scoring on the test set can be chained.\n",
    "knn_hog.fit(X_train_hog, y_train).score(X_test_hog, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "894523a9-440a-4e56-8ad2-4a3afed76ddc",
   "metadata": {},
   "source": [
    "### 使用灰度图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "84e038bc-7414-4120-b362-8d17b84f5970",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x137ed62e0>"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWzElEQVR4nO3da2xdVXYH8P/CSSAvQhybyCShTtIoFRqYBFkRZdAIBiZQNLxEFYEE4gOajKoBFWn6AVGpUKkfmKq8JCqqUNAwFeXRIUAEUTspQkD4kOAAedOS98t5GPJwQuI4zuqHeyKc9Kz/vT6+D4f9/0lR7LO879l3+y7fe8+6e29zd4jID98Fje6AiNSHkl0kEUp2kUQo2UUSoWQXSYSSXSQRI4bS2MxuAfAcgCYA/+ruT7Kfb2lp8fb29qGcUurs9OnTYayvry+MjRw5ctC319TUFMbMLIzJ97Zt24bu7u7cwSqc7GbWBOCfAfwcwC4An5nZEnffELVpb29HZ2dnbow9CKS2WCIdPXo0jHV3d4exlpaW3OMnT54M24wdOzaMRX88AN7/6HMkP9Q/HvPmzQtjQ3kZPw/AJnff4u4nAbwO4I4h3J6I1NBQkn0KgJ0Dvt+VHRORYajmF+jMbKGZdZpZ54EDB2p9OhEJDCXZdwOYNuD7qdmxs7j7InfvcPeO1tbWIZxORIZiKMn+GYBZZjbdzEYBuAfAkup0S0SqrfDVeHc/ZWYPAfgvlEpvL7v7+qK3d8EFKvkPR999910Y27VrVxhbvz7/oXDkyJGwzfz588NYdHVfKjekOru7LwWwtEp9EZEa0tOpSCKU7CKJULKLJELJLpIIJbtIIoZ0Nb6atPBlbbHxZZNCtm/fHsaWL18exo4fP557fNy4cWEbVpZrbm4OY8wPdcJLEXpmF0mEkl0kEUp2kUQo2UUSoWQXScSwuRqvq6aN09vbG8bY1Xh2ZT2auMKWstqzZ08Yu/zyy8MYm0QVVSFSnHiV3j0WSZSSXSQRSnaRRCjZRRKhZBdJhJJdJBHDpvQm1VFkBxRWDtu8eXMYO3ToUBgbM2ZM7vGenp6wzRdffBHG2LZhbW1tYSxSdGLQ+UzP7CKJULKLJELJLpIIJbtIIpTsIolQsoskYkilNzPbBqAHQD+AU+7eUY1OSfWxUtOOHTvC2JYtW8LY119/HcYmTpyYe3zy5Mlhm61bt4axlStXhrEbb7wxjE2YMCGMpaYadfYb3D0u1IrIsKCX8SKJGGqyO4A/mtkqM1tYjQ6JSG0M9WX8de6+28wuBbDMzL5y948H/kD2R2AhwFcbEZHaGtIzu7vvzv7fD+BtAPNyfmaRu3e4e0dra+tQTiciQ1A42c1srJmNP/M1gPkA1lWrYyJSXUN5GT8ZwNvZDKERAP7d3f+zKr06z9Ri66pqz7zq7+8PY319fWHs6NGjYYwtENnV1ZV7/NSpU2Gbyy67LIyx0htrN2/e/3uxCaA2M9uG+2y5wsnu7lsA/LiKfRGRGlLpTSQRSnaRRCjZRRKhZBdJhJJdJBFacHIQohIbK73VYmFD1q7Ibc6YMSOMXXzxxWGMLR4Z9YPNlBs1alQYO3nyZBj79NNPw9iUKVNyj0+dOjVsw8qUDNs/bjiU5fTMLpIIJbtIIpTsIolQsoskQskukghdjT8Hu3p++vTpQR0H+CSTCy+8MIwVvXpbZPunSZMmhbGbbropjK1atSqMbdq0Kfc4mwjDtpqaPXt2GPvqq6/C2EcffZR7/Pbbbw/bjB07NozVAnv8RIo8PvTMLpIIJbtIIpTsIolQsoskQskukgglu0gifrClt1qsCxeVSFjp57vvvgtjV155ZRi76KKLKu9YBYqOxw033BDGWKns2WefzT1+4sSJsE20bh3Ay2FsDbpoksy0adPCNtdee20YY6VDNhGmqakpjB05ciT3eG9v76D7wSYM6ZldJBFKdpFEKNlFEqFkF0mEkl0kEUp2
kUSULb2Z2csAfgFgv7v/KDvWDOANAO0AtgFY4O4Hh9IRNvMnmuHDyklsHbGia4Vt37499/jixYvDNocOHQpje/fuDWM333xzGGOz5YqU2NhYjR8/PowtWLAgjG3YsCH3+JIlS8I2rCy3ZcuWMMbWroseV0uXLg3bTJgwIYw1NzeHMfa7Zg4cOJB7/PDhw2GbqKTLtuuq5Jn9dwBuOefYowA+cPdZAD7IvheRYaxssmf7rX97zuE7ALySff0KgDur2y0Rqbai79knu/uZjzvtRWlHVxEZxoZ8gc5LbxLDN4pmttDMOs2sM3pvIiK1VzTZ95lZGwBk/++PftDdF7l7h7t3tLa2FjydiAxV0WRfAuCB7OsHALxbne6ISK1UUnp7DcD1AFrMbBeAxwE8CeBNM3sQwHYAcQ2mhr755psw9u23515T/B6bgcRmXkUzqNjCi6wfbFbT3Llzw9jkyfElkhEj8n+lrIzD+jh9+vQwxmaOPfzww7nHo/IlAKxcuTKMsdlmW7duDWNRmXXt2rVhG/b4uOaaa8IYG2NWEotKn2wGW5E2ZZPd3e8NQjeWaysiw4c+QSeSCCW7SCKU7CKJULKLJELJLpKIYbPgJJt5Fc1EY7OMPvzwwzDGyj8HD8aT96LSyujRo8M2bKFE9olC1v+ZM2eGsWihyh07doRtWLnm2LFjYSxaKBGIS4AdHR1hG1YOYwt3snGMxqOlpSVsw0qpbIYd2zOPzeiLsBmMUdmWtdEzu0gilOwiiVCyiyRCyS6SCCW7SCKU7CKJqGvp7fjx41i9enVubOTIkWG7qMzAZmuxGUg7d+4MY5deemkYGzNmTO5xNguNzciKxgIA3n///TDGFkSMSl5shh0rC73zzjthjC3cefnll+cev+SSS8I2rCy3fPnyMMbKtrt37w5jEdbHFStWhDFWgmWz9qJFMfv6+sI2UUmUPe71zC6SCCW7SCKU7CKJULKLJELJLpKIul6N7+npwSeffBLGItFV8LvvvjtswyYEsIkO7Ep3dHW0ra0tbMMqBmwiycaNG8NYNB5A3P+JEycWuj12ZZqNVdRu3LhxYZv58+eHse7u7jDGqhrRVfD9+8MFkelkF3aFnFUF2JZd0fizyUvRVXdWddEzu0gilOwiiVCyiyRCyS6SCCW7SCKU7CKJqGT7p5cB/ALAfnf/UXbsCQC/BHBm8a/H3H1pudvq7e3Fpk2bcmOsFDJ79uzc42x9t127doWxzZs3hzFWGorWamPro7HyGptIEt1nAJgyZUoYi0ps+/btG3QbgG/xxCZdROsGsvvMJhTddtttYYxtAxY9Dth2TKzM19zcHMbYfWNr3kWTr9gkqugxF02qASp7Zv8dgFtyjj/j7nOyf2UTXUQaq2yyu/vHAOJPhojIeWEo79kfMrM1ZvaymcWvA0VkWCia7C8AmAlgDoAuAE9FP2hmC82s08w6jx8/XvB0IjJUhZLd3fe5e7+7nwbwIoB55GcXuXuHu3ewlTxEpLYKJbuZDZz5cReAddXpjojUSiWlt9cAXA+gxcx2AXgcwPVmNgeAA9gG4FeVnKy/vz8s17CX+NErAlb6YWULVmpi/YhmFLEZaqwEyGZQ3XfffWGMlfqWLVuWe3zduvjvMSsnRaVSAJg6dWoYi2abRWvkAUBra2sYu+qqq8LYvffeG8aef/753ONsDFnprampKYwxbI3FqIQcbV0FxOW6PXv2hG3KJru7543kS+Xaicjwok/QiSRCyS6SCCW7SCKU7CKJULKLJMLYwozV1tzc7DfffHNu7ODBg2G7aObVjBkzwjbvvffe4DqXYbOG2AKAEbYIISsB3nTTTWGMzaRbv3597nE2q5CVk1hZjs0QPHToUO5xdp/ZFklsEciLL744jK1duzb3OCtRsX6whSNZWZHd72irLPZYjH5nq1atQk9PT+6UQz2ziyRCyS6SCCW7SCKU7CKJULKLJELJLpKIuu71NmLECFqCiESlJrZnG5v1Fi2G
CPAST1RaiRaiBPiec1HJBeClGraf16xZs3KPs/vM9qNj48FmHUa/M3YutmAjG0e2T2DUf3YuVkJj5TD2OGDjGO2Lx0qi0XiwvuuZXSQRSnaRRCjZRRKhZBdJhJJdJBF1vxo/adKk3Nj48ePDdtGkCraNU3QegF+JPXbsWBiLJrWw9cXYFdW9e/eGseXLl4cxdt+iCS9suyN2FZn1kYnGmF0tZlUGtsUTqzRE5yt6NZ5hV+rZ4yqqarAto9i5InpmF0mEkl0kEUp2kUQo2UUSoWQXSYSSXSQRlWz/NA3A7wFMRmm7p0Xu/pyZNQN4A0A7SltALXD3eCG5TFSKYuWTqA0rP7CthBhWIom26WHrkrGSEVt3j7VjpaGo/3PmzAnbsPIam7jCyqXR9kqsTMnW62PlQVZGi7ZQYo83NumGlVLZ2nXsvu3bt2/QbaJxZG0qeWY/BeA37n4FgGsA/NrMrgDwKIAP3H0WgA+y70VkmCqb7O7e5e6fZ1/3ANgIYAqAOwC8kv3YKwDurFEfRaQKBvWe3czaAcwFsALAZHfvykJ7UXqZLyLDVMXJbmbjALwF4BF3P2tlAi+9ycl9o2NmC82s08w62Ta5IlJbFSW7mY1EKdFfdffF2eF9ZtaWxdsA5H4o290XuXuHu3eMGTOmGn0WkQLKJruVLlu+BGCjuz89ILQEwAPZ1w8AeLf63RORaqlkes9PANwPYK2ZfZkdewzAkwDeNLMHAWwHsKDcDfX394flpuPHj4ftolLIZZddFrbZuXNnGGOz5Vjprb29Pfc4K7319fWFMYaNByuvRGPV1dWVexzga7ixc40ePTqMnThxIvc4W4uNldDYq0LWx6gfDCvzsfsclfkAvlVWFGOPnWisWNmwbLK7+3IAUVHyxnLtRWR40CfoRBKhZBdJhJJdJBFKdpFEKNlFElHXBSfdPSwnsBlD0SwvVnJhixey8horeXV3d+ceZ7PQipbeWP/Z/Y5KW6z0xm6PYTPzImw8WOmKzZZjj52oFMVmr7HyYNF2rP/R7E12LlamDNsMuoWInJeU7CKJULKLJELJLpIIJbtIIpTsIomoa+kN4OWESLRPGVsMkZXXWNmCxaLFN1gbVtZiYzFhwoRC7aIyYC0WWGQlx2hM2FixcxUdx+i+FVngtFy7IuUwIC69sZIii0X0zC6SCCW7SCKU7CKJULKLJELJLpKIul6NP336NI4cOZIbO3z4cNguasOuuDMTJ04MY2yNsQi7CsvWHmMTP9gVYTapIrpCzrbKYrGiojFhV7OLXn1m1QS2nlyEVRlYjPWD/a6jGPu9FNnWSs/sIolQsoskQskukgglu0gilOwiiVCyiySibOnNzKYB+D1KWzI7gEXu/pyZPQHglwAOZD/6mLsvZbfV19eH/ftz93+ka5NF68KxskqRUgc7FxCXNVg5pqgia/IBwNixY3OPs/IgKxkVmXDBzsdKQwwrs7I+RmXKohNaik4oKjL+bFuxKMbOU8mj9BSA37j752Y2HsAqM1uWxZ5x93+q4DZEpMEq2eutC0BX9nWPmW0EMKXWHROR6hrUe3YzawcwF8CK7NBDZrbGzF42s/hjaSLScBUnu5mNA/AWgEfc/QiAFwDMBDAHpWf+p4J2C82s08w6i3x0UUSqo6JkN7ORKCX6q+6+GADcfZ+797v7aQAvApiX19bdF7l7h7t3sEX0RaS2yia7lS4/vgRgo7s/PeB424AfuwvAuup3T0SqpZKr8T8BcD+AtWb2ZXbsMQD3mtkclMpx2wD8qtwNse2f2Ayf6BUBK3WMGTOmXHdyFVlzrej2SUXLWkVmZbGZcmwci87yirCxYmUtNlORlW1PnDhRWccGYOUr9jhlv0/2FrbI4yfqB7utSq7GLweQ91ugNXURGV70CTqRRCjZRRKhZBdJhJJdJBFKdpFE1HXByREjRqC1tTU3xsouUTmBlVwYNrOtyNY/RcsxRWfLsXZRX4qOVVFF+kHLRuQ+F1mokrUp
uv1T0RJmhPUxGis6827QPRCR85KSXSQRSnaRRCjZRRKhZBdJhJJdJBF1Lb01NTVhwoQJuTFWvorKCb29vWGbo0ePhjE2A4zFolJI0ZlcbHYVu002VtFtFhnfcli7IjO52O2xcil7HETY2DNF98yLFgIF+GNusG3Y/dIzu0gilOwiiVCyiyRCyS6SCCW7SCKU7CKJqGvpzczCUhQrUUWL9bHFBNkCf2xWE4tFpTc2k4vF2LmK7jcWtWOlsFqU3iJFF+dkY8X27qs2NhOt6Iy4IqLHlWa9iYiSXSQVSnaRRCjZRRKhZBdJRNmr8WZ2EYCPAVyY/fwf3P1xM5sO4HUAkwCsAnC/u9NtWt09vBrLJjNEV9ZZmyK3B/DJDBG2vtjo0aMLtWNX8dkV16j/7EoxixXZDothE1rYfWbViSJbVLFKzrFjx8IYw7YcY/2P7jf7vUSPK1qpCSPf6wXwM3f/MUrbM99iZtcA+C2AZ9z9TwEcBPBgBbclIg1SNtm95Mx80ZHZPwfwMwB/yI6/AuDOWnRQRKqj0v3Zm7IdXPcDWAZgM4BD7n7mEwa7AEypSQ9FpCoqSnZ373f3OQCmApgH4M8qPYGZLTSzTjPrLLJ9rohUx6Cuxrv7IQAfAvhzAJeY2ZkrI1MB7A7aLHL3DnfvqOfHGkXkbGWT3cxazeyS7OvRAH4OYCNKSf+X2Y89AODdGvVRRKqgktpJG4BXzKwJpT8Ob7r7e2a2AcDrZvYPAL4A8FK5G3L3sMzAymHR5IOi2+0UXX8sKkMVnbTCJlWwdcmKrHlXdNINixVZ127UqFGFzsVKdmwco/MVHQ829qwfrATLxiRS5DFcNtndfQ2AuTnHt6D0/l1EzgP6BJ1IIpTsIolQsoskQskukgglu0girNprY9GTmR0AsD37tgVAd91OHlM/zqZ+nO1868efuHtrXqCuyX7Wic063b2jISdXP9SPBPuhl/EiiVCyiySikcm+qIHnHkj9OJv6cbYfTD8a9p5dROpLL+NFEtGQZDezW8zsf8xsk5k92og+ZP3YZmZrzexLM+us43lfNrP9ZrZuwLFmM1tmZl9n/09sUD+eMLPd2Zh8aWa31qEf08zsQzPbYGbrzeyvs+N1HRPSj7qOiZldZGYrzWx11o+/z45PN7MVWd68YWaDmy7n7nX9B6AJpWWtZgAYBWA1gCvq3Y+sL9sAtDTgvD8FcDWAdQOO/SOAR7OvHwXw2wb14wkAf1Pn8WgDcHX29XgA/wvginqPCelHXccEgAEYl309EsAKANcAeBPAPdnxfwHwV4O53UY8s88DsMndt3hp6enXAdzRgH40jLt/DODbcw7fgdLCnUCdFvAM+lF37t7l7p9nX/egtDjKFNR5TEg/6spLqr7IayOSfQqAnQO+b+RilQ7gj2a2yswWNqgPZ0x2967s670AJjewLw+Z2ZrsZX7N304MZGbtKK2fsAINHJNz+gHUeUxqschr6hfornP3qwH8BYBfm9lPG90hoPSXHaU/RI3wAoCZKO0R0AXgqXqd2MzGAXgLwCPufmRgrJ5jktOPuo+JD2GR10gjkn03gGkDvg8Xq6w1d9+d/b8fwNto7Mo7+8ysDQCy//c3ohPuvi97oJ0G8CLqNCZmNhKlBHvV3Rdnh+s+Jnn9aNSYZOc+hEEu8hppRLJ/BmBWdmVxFIB7ACypdyfMbKyZjT/zNYD5ANbxVjW1BKWFO4EGLuB5Jrkyd6EOY2KlhfNeArDR3Z8eEKrrmET9qPeY1GyR13pdYTznauOtKF3p3AzgbxvUhxkoVQJWA1hfz34AeA2ll4N9KL33ehClPfM+APA1gP8G0NygfvwbgLUA1qCUbG116Md1KL1EXwPgy+zfrfUeE9KPuo4JgKtQWsR1DUp/WP5uwGN2JYBNAP4DwIWDuV19gk4kEalfoBNJhpJdJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUS8X9e50I7dqGiswAAAABJRU5ErkJggg==
\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Preview one training image. Each row of t1_data holds the three colour\n",
    "# planes (32x32 each) back to back, so reshape to (3, 32, 32) and move the\n",
    "# channel axis last to get the (32, 32, 3) layout rgb2gray indexes into.\n",
    "mat = t1_data[2].reshape(3, 32, 32).transpose(1, 2, 0)\n",
    "gray = rgb2gray(mat)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(gray, cmap=\"gray\")\n",
    "ax.set_title(\"t1_data[2] (grayscale)\")\n",
    "ax.axis(\"off\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "64f5f208-1dec-4864-a0c3-7ce001d29c04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.        , 0.98840429, 0.98889725, ..., 0.11760922, 0.15282791,\n",
       "        0.23513757]])"
      ]
     },
     "execution_count": 123,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MinMaxScaler scales each *column* of a 2-D array on its own, so applying it\n",
    "# to a 32x32 image would stretch every pixel column to [0, 1] independently.\n",
    "# Normalise the whole image by its own global min/max instead; a constant\n",
    "# image maps to all zeros, as MinMaxScaler does for a zero range.\n",
    "sample_gray = rgb2gray(t1_data[2].reshape(3, 32, 32).transpose(1, 2, 0))\n",
    "sample_range = sample_gray.max() - sample_gray.min()\n",
    "((sample_gray - sample_gray.min()) / (sample_range or 1.0)).reshape(1, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2418319b-fd5b-4779-a297-aec8d778d7db",
   "metadata": {},
   "source": [
    "#### 准备灰度图训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "7b37ac5a-3936-43dd-87e5-ac5306b0a17b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 1024)"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_batches = [t1_data, t2_data, t3_data, t4_data, t5_data]\n",
    "n_train = sum(len(batch) for batch in train_batches)\n",
    "X_train_gray = np.empty((n_train, 32 * 32))\n",
    "\n",
    "# Grayscale + flatten every image into one 1024-value row. Loop-local names\n",
    "# are used so the `mat` / `gray` preview variables are not overwritten.\n",
    "train_items = (item for batch in train_batches for item in batch)\n",
    "for idx, item in enumerate(train_items):\n",
    "    item_rgb = item.reshape(3, 32, 32).transpose(1, 2, 0)\n",
    "    X_train_gray[idx] = rgb2gray(item_rgb).ravel()\n",
    "\n",
    "# Min-max normalise each image (row) over all of its pixels, rather than per\n",
    "# pixel column as MinMaxScaler().fit_transform(image) would. Constant images\n",
    "# map to all zeros, matching MinMaxScaler's zero-range handling. In-place\n",
    "# ops avoid extra full-size temporaries of this large array.\n",
    "row_min = X_train_gray.min(axis=1, keepdims=True)\n",
    "row_range = X_train_gray.max(axis=1, keepdims=True) - row_min\n",
    "row_range[row_range == 0] = 1.0\n",
    "X_train_gray -= row_min\n",
    "X_train_gray /= row_range\n",
    "\n",
    "X_train_gray.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70cae9e2-576b-44ba-8cf3-467a10852633",
   "metadata": {},
   "source": [
    "#### 训练集交叉验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "1dbb6187-97cb-4130-a7c7-8d2c45614f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 11\n",
    "knn_gray = KNeighborsClassifier(n_neighbors=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "456f7fb9-e5c2-4ebd-82b6-e238f2e7f28c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.3357, 0.3319, 0.3358, 0.3272, 0.3309])"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores = cross_val_score(knn_gray, X_train_gray, y_train, cv=5, scoring=\"accuracy\")\n",
    "scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a01339b3-3e5c-47ab-a8b9-af70546d830d",
   "metadata": {},
   "source": [
    "#### 灰度图测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "624cb29e-37c2-43ae-8a35-de6c47bbe703",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 1024)"
      ]
     },
     "execution_count": 127,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test_gray = np.empty((len(tst_data), 32 * 32))\n",
    "\n",
    "# Grayscale + flatten every test image into one 1024-value row; loop-local\n",
    "# names keep the `mat` / `gray` preview variables intact.\n",
    "for idx, item in enumerate(tst_data):\n",
    "    item_rgb = item.reshape(3, 32, 32).transpose(1, 2, 0)\n",
    "    X_test_gray[idx] = rgb2gray(item_rgb).ravel()\n",
    "\n",
    "# Per-image min-max normalisation over all pixels, identical to the training\n",
    "# features; constant images map to all zeros.\n",
    "row_min = X_test_gray.min(axis=1, keepdims=True)\n",
    "row_range = X_test_gray.max(axis=1, keepdims=True) - row_min\n",
    "row_range[row_range == 0] = 1.0\n",
    "X_test_gray -= row_min\n",
    "X_test_gray /= row_range\n",
    "\n",
    "X_test_gray.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "7df06dfd-b128-4245-a5f5-7f33a20cd24b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3402"
      ]
     },
     "execution_count": 128,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit on the full grayscale training set, then report test-set accuracy.\n",
    "knn_gray.fit(X_train_gray, y_train).score(X_test_gray, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd5d2726-b2d4-4d45-a5df-6993c71a6f3b",
   "metadata": {},
   "source": [
    "### 尝试SGDClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "6cbedaec-a5e7-4c0c-a6f7-7487d06818d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5211"
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Logistic-regression-style linear classifier on the HOG features, fitted\n",
    "# with SGD. scikit-learn >= 1.1 names this loss log_loss (the old alias log\n",
    "# was removed in 1.3). random_state fixes the per-epoch shuffling so the\n",
    "# score is reproducible.\n",
    "sgd_hog = SGDClassifier(loss=\"log_loss\", random_state=0)\n",
    "sgd_hog.fit(X_train_hog, y_train)\n",
    "sgd_hog.score(X_test_hog, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "83e8ea6f-7268-4fd5-b4c2-a880cf94eca4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2332"
      ]
     },
     "execution_count": 131,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same SGD logistic-regression-style classifier on the normalised grayscale\n",
    "# pixels. scikit-learn >= 1.1 names this loss log_loss (the old alias log\n",
    "# was removed in 1.3). random_state fixes the per-epoch shuffling so the\n",
    "# score is reproducible.\n",
    "sgd_gray = SGDClassifier(loss=\"log_loss\", random_state=0)\n",
    "sgd_gray.fit(X_train_gray, y_train)\n",
    "sgd_gray.score(X_test_gray, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
